Multi-label prediction with banknotes (this run trains the single-class variant; the multi-label code paths are left commented out)

In [1]:
# Auto-reload edited source modules and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
import fastai
# fastai v1 convention: the star import brings in the vision API plus common
# dependencies (torch, numpy, Path, partial, gc, ...) used throughout below.
from fastai.vision import *
In [3]:
# Uncomment to force CPU-only execution (when no CUDA device is available):
# fastai.torch_core.defaults.device = 'cpu'
# defaults.device = 'cpu'

Multi-class classification

In [4]:
# Dataset root and image subfolder; images are organized as
# imgs/<currency>/<denomination>/*.jpg (see Out[6] below).
# NOTE(review): hardcoded absolute path — only portable on this Jupyter
# image; consider deriving it from an environment variable instead.
path = Path('/home/jupyter/.fastai/data/banknotes/')
path_imgs=path/'imgs'
path_imgs.mkdir(parents=True, exist_ok=True)
path_imgs
Out[4]:
PosixPath('/home/jupyter/.fastai/data/banknotes/imgs')
In [5]:
# Collect images recursively and hold out 20% at random for validation.
# seed=42 pins the random split so validation metrics are comparable
# across notebook re-runs (without it every run gets a different split).
src = (ImageList.from_folder(path_imgs, recurse=True)
               .split_by_rand_pct(valid_pct=.2, seed=42))

src
Out[5]:
ItemLists;

Train: ImageList (211 items)
Image (3, 256, 500),Image (3, 428, 500),Image (3, 255, 500),Image (3, 500, 305),Image (3, 279, 500)
Path: /home/jupyter/.fastai/data/banknotes/imgs;

Valid: ImageList (52 items)
Image (3, 355, 500),Image (3, 232, 500),Image (3, 260, 500),Image (3, 259, 500),Image (3, 472, 500)
Path: /home/jupyter/.fastai/data/banknotes/imgs;

Test: None
In [6]:
# Peek at one training item: the path encodes the label as
# <currency>/<denomination> folders (e.g. .../euro/5/IMG_....jpg).
src.train.items[0]
Out[6]:
PosixPath('/home/jupyter/.fastai/data/banknotes/imgs/euro/5/IMG_20190730_232521.jpg')
In [9]:
# Single-class labeling: the label is the image's parent folder relative to
# path_imgs, e.g. 'euro/5' (currency + denomination as one combined class).
def func(i):
    """Return the single-class label string for image path `i`."""
    return str(i.parent.relative_to(path_imgs))

# Multi-label alternative — one label per path component, e.g. ('euro', '5'):
# def func(i): return i.parent.relative_to(path_imgs).parts

func(src.train.items[0])
Out[9]:
'euro/5'
In [10]:
# Apply the labeling function; produces single-class CategoryList labels
# such as 'euro/5' (see Out[10]).
ll = src.label_from_func(func); ll
# Alternative: label from the immediate parent folder name only:
#ll = src.label_from_folder(); ll
Out[10]:
LabelLists;

Train: LabelList (211 items)
x: ImageList
Image (3, 256, 500),Image (3, 428, 500),Image (3, 255, 500),Image (3, 500, 305),Image (3, 279, 500)
y: CategoryList
euro/5,euro/5,euro/5,euro/5,euro/5
Path: /home/jupyter/.fastai/data/banknotes/imgs;

Valid: LabelList (52 items)
x: ImageList
Image (3, 355, 500),Image (3, 232, 500),Image (3, 260, 500),Image (3, 259, 500),Image (3, 472, 500)
y: CategoryList
usd/50,usd/1,euro/20,euro/500,usd/1
Path: /home/jupyter/.fastai/data/banknotes/imgs;

Test: None
In [11]:
# Data augmentation: banknotes can appear at any orientation, so both flips
# plus rotations up to 90° are enabled; zoom/lighting/warp are set fairly
# aggressively to compensate for the small dataset (~260 images total).
tfms = get_transforms(do_flip=True,flip_vert=True, 
                      max_rotate=90, 
                      max_zoom=1.5, 
                      max_lighting=0.5, 
                      max_warp=0.5)
In [12]:
# To make the random validation split reproducible, seed NumPy *before*
# calling split_by_rand_pct above (or pass seed=42 to it directly):
# np.random.seed(42)
In [13]:
def get_data(size, bs):
    """Build a normalized DataBunch from the labeled lists `ll`.

    Applies `tfms` at image size `size` and batches with batch size `bs`
    (both coerced to int), then normalizes with ImageNet statistics.
    """
    bunch = (ll.transform(tfms, size=int(size))
               .databunch(bs=int(bs))  # CPU-only runs: also pass num_workers=0
               .normalize(imagenet_stats))
    return bunch

# Start training at half resolution (128 px) for progressive resizing.
size, bs = 256/2, 20
data = get_data(size, bs)
In [14]:
# Visual sanity check: a grid of augmented training images with their labels.
data.show_batch(rows=4, figsize=(12,9))
In [15]:
# ResNet-50 backbone for transfer learning.
arch = models.resnet50
In [31]:
# Thresholded metrics for the multi-label variant (unused in this run).
acc_02 = partial(accuracy_thresh, thresh=0.2)
f_score = partial(fbeta, thresh=0.2)
# Multi-label learner:
#learn = cnn_learner(data, arch, metrics=[acc_02, f_score])
# Single-class learner with plain accuracy (the path used below):
learn = cnn_learner(data, arch, metrics=[accuracy])

We use the LR Finder to pick a good learning rate.

In [32]:
# Sweep learning rates and plot loss vs. LR to pick a starting value.
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

Then we can fit the head of our network.

In [33]:
lr = 1e-2  # chosen from the LR-finder plot above
In [34]:
# Train the head only (backbone frozen) for 10 one-cycle epochs;
# the ShowGraph callback plots train/valid losses as training runs.
learn.fit_one_cycle(10, slice(lr),callbacks=ShowGraph(learn))
epoch train_loss valid_loss accuracy time
0 2.719794 2.119684 0.384615 00:02
1 2.264033 1.512351 0.500000 00:02
2 2.006482 1.563698 0.557692 00:02
3 1.789367 1.178731 0.615385 00:02
4 1.626022 1.039096 0.634615 00:02
5 1.461312 0.827594 0.750000 00:02
6 1.358138 0.838351 0.788462 00:02
7 1.250934 0.839516 0.750000 00:02
8 1.148749 0.775454 0.750000 00:02
9 1.095168 0.798045 0.750000 00:02
In [35]:
# Continue head-only training at a 10x lower learning rate.
lr = 1e-3
learn.fit_one_cycle(5, slice(lr),callbacks=ShowGraph(learn))
epoch train_loss valid_loss accuracy time
0 0.927123 0.788109 0.769231 00:02
1 0.900647 0.745238 0.788462 00:02
2 0.860542 0.708181 0.769231 00:02
3 0.818388 0.720342 0.788462 00:02
4 0.767459 0.702057 0.807692 00:02
In [36]:
# Show predictions vs. ground truth on a few validation images.
learn.show_results(rows=3)
In [37]:
# Checkpoint: head trained, backbone still frozen, 128 px images.
learn.save('stage-1-rn50')

...And fine-tune the whole model:

In [38]:
# Make all layers trainable for fine-tuning.
learn.unfreeze()
In [39]:
# Re-run the LR finder for the unfrozen model.
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [40]:
# Fine-tune with discriminative learning rates: 1e-5 for the earliest
# layer group up to lr/5 (= 2e-5) for the head.
lr=1e-4
learn.fit_one_cycle(5, slice(1e-5, lr/5),callbacks=ShowGraph(learn))
epoch train_loss valid_loss accuracy time
0 0.959785 0.710335 0.807692 00:02
1 0.824092 0.682993 0.807692 00:02
2 0.760623 0.638910 0.807692 00:02
3 0.701415 0.688630 0.807692 00:02
4 0.694100 0.682584 0.788462 00:02
In [41]:
# Checkpoint: fully fine-tuned at 128 px.
learn.save('stage-2-rn50')
In [42]:
# Repeat the same fine-tuning schedule for 5 more epochs.
learn.fit_one_cycle(5, slice(1e-5, lr/5),callbacks=ShowGraph(learn))
epoch train_loss valid_loss accuracy time
0 0.754785 0.674954 0.788462 00:02
1 0.717118 0.701267 0.769231 00:02
2 0.666054 0.700104 0.769231 00:02
3 0.610322 0.671121 0.769231 00:02
4 0.595243 0.676425 0.807692 00:02
In [43]:
# Inspect predictions after fine-tuning at 128 px.
learn.show_results(rows=3)
In [44]:
# Release Python and CUDA memory before rebuilding data at a larger size.
gc.collect()
torch.cuda.empty_cache()
In [45]:
# Progressive resizing: move to full 256 px images for the next stage.
# The batch size was originally written as 10/4 (= 2.5) and silently
# truncated to 2 by get_data's int() cast; state the effective value
# explicitly so the real batch size is visible.
size, bs = 256, 2
data = get_data(size, bs)
In [46]:
# Freeze the backbone again to retrain the head at the new resolution.
learn.freeze()
In [47]:
# LR finder for the 256 px data with a frozen backbone.
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [48]:
lr=1e-3  # chosen from the LR-finder plot above
In [49]:
# Head-only training at 256 px.
learn.fit_one_cycle(5, slice(lr),callbacks=ShowGraph(learn))
epoch train_loss valid_loss accuracy time
0 0.617516 0.723374 0.769231 00:05
1 0.654863 0.724263 0.826923 00:04
2 0.622503 0.721524 0.807692 00:04
3 0.587861 0.733027 0.788462 00:04
4 0.557029 0.735960 0.788462 00:04
In [50]:
# A second round of the same head-only schedule at 256 px.
learn.fit_one_cycle(5, slice(lr),callbacks=ShowGraph(learn))
epoch train_loss valid_loss accuracy time
0 0.620950 0.716527 0.807692 00:03
1 0.572661 0.676043 0.807692 00:04
2 0.577830 0.702346 0.807692 00:04
3 0.527988 0.676882 0.807692 00:04
4 0.500437 0.687347 0.807692 00:04
In [51]:
# Checkpoint: head retrained at 256 px, backbone frozen.
learn.save('stage-1-256-rn50')
In [52]:
# Inspect predictions after the 256 px head training.
learn.show_results()
In [53]:
# Unfreeze for a final fine-tuning pass at 256 px.
learn.unfreeze()
In [54]:
# Fine-tune all layers with discriminative learning rates (1e-5 .. lr/5).
learn.fit_one_cycle(5, slice(1e-5, lr/5),callbacks=ShowGraph(learn))
epoch train_loss valid_loss accuracy time
0 0.443138 0.687663 0.807692 00:05
1 0.450793 0.661966 0.807692 00:04
2 0.399322 0.596265 0.807692 00:04
3 0.403629 0.525856 0.807692 00:04
4 0.375984 0.510671 0.807692 00:04
In [55]:
# Final checkpoint: fully fine-tuned at 256 px.
learn.save('stage-2-256-rn50')
In [56]:
# Serialize the model plus its data transforms for inference
# (writes 'export.pkl' under the learner's path).
learn.export()
In [ ]:
# Display the dataset root path.
path
In [ ]:
# If using single class
In [57]:
# Build an interpretation object from validation predictions
# (single-class path only).
interp = ClassificationInterpretation.from_learner(learn)
In [58]:
# Hardest validation examples, with activation heatmaps overlaid
# (heatmap=True) to show where the model focused.
interp.plot_top_losses(9, figsize=(15,11),heatmap=True)
In [59]:
# Confusion matrix over the validation set.
interp.plot_confusion_matrix()
In [ ]: